import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from torch import Tensor

from models import VariableMLP


class TrivialPredictor(nn.Module):
    """Baseline that predicts one shared probability for every entry of ``Y``.

    The probability is a single learnable scalar initialised to 0.5. The Z and
    X inputs are accepted but never read.
    """

    def __init__(self, Z_dim=None, X_dim=None):
        super().__init__()
        # Stored as plain attributes; nothing in this class reads them.
        self.X_dim = X_dim
        self.Z_dim = Z_dim
        self.pred = nn.Parameter(torch.tensor(0.5))

    def eval_seq(self, Zkwargs, X, Y, return_preds=False):
        """Return the element-wise binary cross-entropy of the constant prediction.

        Args:
            Zkwargs: Ignored.
            X: Ignored.
            Y: Target tensor; the prediction is broadcast to its shape.
            return_preds: When true, also return the prediction tensor.

        Returns:
            The unreduced loss tensor (same shape as ``Y``), or the tuple
            ``(loss, predictions)`` when ``return_preds`` is true.
        """
        constant_probs = self.pred * torch.ones_like(Y)
        per_entry_loss = F.binary_cross_entropy(constant_probs, Y, reduction='none')
        if not return_preds:
            return per_entry_loss
        return per_entry_loss, constant_probs


class MarginalPredictorContext(nn.Module):
    """Non-sequential predictor: one MLP applied to every ``(Z, X[:, j])`` pair.

    The Z input is turned into a 2-D embedding by one of three encoders
    (BERT wrapper, categorical embedding table, or pass-through), repeated
    along the column axis of ``X``, concatenated with ``X`` on the last axis
    and fed to ``top_layer``.
    """

    def __init__(self, bert_encoder=None, category_args=None, X_dim=None, Z_dim=None, MLP_layer=6, MLP_width=50, rand_prior=False):
        """Build the Z encoder and the prediction MLP.

        Args:
            bert_encoder: If given, wrapped in ``BertEncoder`` and used for Z.
            category_args: If given (and ``bert_encoder`` is None), a dict with
                keys ``'num_embeddings'`` and ``'embedding_dim'`` used to build
                an ``nn.Embedding`` for Z.
            X_dim: Size of the last axis of ``X``. Required.
            Z_dim: Size of the raw Z vector when Z is used without an encoder.
            MLP_layer: ``num_layers`` forwarded to ``VariableMLP``.
            MLP_width: ``width`` forwarded to ``VariableMLP``.
            rand_prior: ``rand_prior`` forwarded to ``VariableMLP``.

        Raises:
            ValueError: If none of ``bert_encoder``, ``category_args`` or
                ``Z_dim`` is given, or if ``X_dim`` is None.
        """
        super().__init__()
        self.use_bert = False
        self.use_category = False
        self.MLP_width = MLP_width
        self.MLP_layer = MLP_layer
        self.rand_prior = rand_prior

        if bert_encoder is not None:
            # NOTE(review): BertEncoder is not imported in the visible part of
            # this file -- confirm it is in scope before using this branch.
            self.z_encoder = BertEncoder(bert_encoder)
            self.z_encoder_output_dim = self.z_encoder.output_dim()
            self.use_bert = True
        elif category_args is not None:
            self.z_encoder = nn.Embedding(num_embeddings=category_args['num_embeddings'],
                                          embedding_dim=category_args['embedding_dim'])
            nn.init.xavier_uniform_(self.z_encoder.weight)
            self.z_encoder_output_dim = category_args['embedding_dim']
            self.use_category = True
        elif Z_dim is not None:
            # nn.Identity instead of a lambda so the module stays picklable.
            self.z_encoder = nn.Identity()
            self.z_encoder_output_dim = Z_dim
        else:
            # Previously `X_dim is not None` alone also entered the branch
            # above and then failed with a TypeError on `None + X_dim`.
            raise ValueError("Need a Z feature")

        if X_dim is None:
            raise ValueError("X_dim is required to size the MLP input")
        self.X_dim = X_dim

        self.top_layer = VariableMLP(input_dim=self.z_encoder_output_dim + self.X_dim,
                                     num_layers=MLP_layer, width=MLP_width, rand_prior=rand_prior)

    def _build_input(self, Zkwargs, X):
        """Encode Z, repeat it across the columns of ``X`` and concatenate with ``X``."""
        embed = self.z_encoder(Zkwargs)

        bs, ncol, _ = X.shape
        bsZ, _ = embed.shape
        assert bs == bsZ, "Z and X must share the same batch size"
        embedZ = embed.unsqueeze(1).repeat((1, ncol, 1))

        return torch.cat([embedZ, X], 2)

    def forward(self, Zkwargs, X):
        """Return ``top_layer`` predictions with the last axis squeezed away.

        Args:
            Zkwargs: Input handed to the Z encoder; the encoder output must be
                2-D ``(batch, z_dim)``.
            X: 3-D tensor ``(batch, columns, x_dim)``.

        Returns:
            ``top_layer`` output after ``squeeze(2)``; ``(batch, columns)``
            assuming ``top_layer`` emits one value per position -- TODO confirm
            against ``VariableMLP``.
        """
        return self.top_layer(self._build_input(Zkwargs, X)).squeeze(2)

    def get_features(self, Zkwargs, X):
        """Return the activations after the first five child modules of ``top_layer.model``.

        Which layer this corresponds to depends on how ``VariableMLP`` lays out
        its ``model`` attribute, which is not visible here.
        """
        input_ = self._build_input(Zkwargs, X)

        partial_top_layer = nn.Sequential(*list(self.top_layer.model.children())[:5])
        return partial_top_layer(input_)

    def eval_seq(self, Zkwargs, X, Y, N=None, return_preds=False, trainlen=None, exact=False):
        """Return the unreduced binary cross-entropy on the first ``N`` columns.

        Args:
            Zkwargs: Input handed to the Z encoder.
            X: 3-D tensor ``(batch, columns, x_dim)``.
            Y: 2-D target tensor ``(batch, columns)``.
            N: Number of leading columns to score; defaults to, and is capped
                at, ``Y.shape[1]``.
            return_preds: When true, also return the predictions.
            trainlen: Unused; accepted for signature compatibility.
            exact: Unused; accepted for signature compatibility.

        Returns:
            The loss matrix ``(batch, N)``, or ``(loss_matrix, predictions)``
            when ``return_preds`` is true.
        """
        if N is None:
            N = Y.shape[1]
        N = min(N, Y.shape[1])

        # Not a sequence model (marginal loss). X is truncated to the same N
        # columns as Y; without this, N < Y.shape[1] gave mismatched shapes.
        p_hat_pred = self.forward(Zkwargs, X[:, :N])
        loss_matrix = F.binary_cross_entropy(p_hat_pred, Y[:, :N], reduction='none')
        if return_preds:
            return loss_matrix, p_hat_pred
        return loss_matrix


class SequentialPredictorContext(nn.Module):
    """Sequential predictor: an MLP on ``(Z embedding, X_j, summary of earlier columns)``.

    The summary ("state") of a history ``(X, Y)`` is the concatenation of

        cov  = inverse(I + sum_t x_t x_t^T)
        mean = cov @ (init_mean * 1 + sum_t x_t * y_t)

    flattened per batch row, with every entry repeated ``repeat_suffstat``
    times. Z is optional: when no Z source is configured the encoder returns
    ``None`` and the Z block is left out of the MLP input.
    """

    def __init__(self, bert_encoder=None, category_args=None, X_dim=None, Z_dim=None, MLP_layer=6, 
                 MLP_width=50, init_mean=0, repeat_suffstat=100):
        """Build the Z encoder and the prediction MLP.

        Args:
            bert_encoder: If given, wrapped in ``BertEncoder`` and used for Z.
            category_args: If given (and ``bert_encoder`` is None), a dict with
                keys ``'num_embeddings'`` and ``'embedding_dim'`` used to build
                an ``nn.Embedding`` for Z.
            X_dim: Size of the last axis of ``X``. Required.
            Z_dim: Size of the raw Z vector when Z is used without an encoder.
            MLP_layer: ``num_layers`` forwarded to ``VariableMLP``.
            MLP_width: ``width`` forwarded to ``VariableMLP``.
            init_mean: Scalar that fills the mean vector of the empty-history state.
            repeat_suffstat: How many times each state entry is repeated in
                the MLP input.

        Raises:
            ValueError: If ``X_dim`` is None.
        """
        super().__init__()
        self.use_bert = False
        self.use_category = False
        self.init_mean = init_mean
        self.MLP_layer = MLP_layer
        self.MLP_width = MLP_width
        self.repeat_suffstat = repeat_suffstat
        self.X_dim = X_dim

        if bert_encoder is not None:
            # NOTE(review): BertEncoder is not imported in the visible part of
            # this file -- confirm it is in scope before using this branch.
            self.z_encoder = BertEncoder(bert_encoder)
            self.z_encoder_output_dim = self.z_encoder.output_dim()
            self.use_bert = True

        elif category_args is not None:
            self.z_encoder = nn.Embedding(num_embeddings=category_args['num_embeddings'],
                                          embedding_dim=category_args['embedding_dim'])
            nn.init.xavier_uniform_(self.z_encoder.weight)
            self.z_encoder_output_dim = category_args['embedding_dim']
            self.use_category = True

        elif Z_dim is not None:
            # nn.Identity instead of a lambda so this configuration stays picklable.
            self.z_encoder = nn.Identity()
            self.z_encoder_output_dim = Z_dim

        else:
            # There are no Z features: the encoder yields None and the methods
            # below leave the Z block out of the MLP input.
            self.z_encoder = lambda x: None
            self.z_encoder_output_dim = 0

        if X_dim is None:
            raise ValueError("X_dim is required to size the sufficient statistic")

        # Size of the sufficient statistic: mean vector plus flattened covariance.
        self.suffStatDim = X_dim + X_dim**2
        self.top_layer = VariableMLP(input_dim=self.z_encoder_output_dim + self.X_dim + self.suffStatDim*repeat_suffstat,
                                     num_layers=MLP_layer, width=MLP_width)

    def get_device(self):
        """Return the device of the first parameter of ``top_layer``."""
        return next(self.top_layer.parameters()).device

    @staticmethod
    def _join_features(z, x, state):
        """Concatenate the MLP input blocks on the last axis, skipping ``z`` when it is None."""
        blocks = [x, state] if z is None else [z, x, state]
        return torch.cat(blocks, 2)

    def init_model_states(self, batch_size):
        """Return the empty-history state for ``batch_size`` rows.

        The state is ``[init_mean * 1, flatten(I)]`` with every entry repeated
        ``repeat_suffstat`` times, giving shape
        ``(batch_size, (X_dim + X_dim**2) * repeat_suffstat)``.
        """
        device = self.get_device()
        mean = torch.ones(self.X_dim, device=device) * self.init_mean
        cov = torch.eye(self.X_dim, device=device)
        init_state = torch.cat([mean, cov.flatten()])
        return init_state.repeat((batch_size, 1)).repeat_interleave(self.repeat_suffstat, dim=1).to(device)

    def get_state(self, X, Y):
        """Summarise a history ``(X, Y)`` into the state vector fed to the MLP.

        Args:
            X: 3-D tensor ``(batch, steps, X_dim)``.
            Y: 2-D tensor ``(batch, steps)``.

        Returns:
            Tensor of shape ``(batch, (X_dim + X_dim**2) * repeat_suffstat)``.
            Rows of ``X`` that are all zero contribute nothing to either
            statistic, which is what lets callers pass zero-padded or
            mask-multiplied histories.
        """
        if Y.shape[1] == 0:
            return self.init_model_states(X.shape[0])

        device = self.get_device()
        # Per-row sum of outer products x_t x_t^T, plus the identity.
        inv_cov = torch.eye(self.X_dim, device=device) + torch.einsum('ijk,ijl->ikl', X, X)
        cov = torch.linalg.inv(inv_cov)

        bs = X.shape[0]
        prior_mean = torch.ones(self.X_dim, device=device) * self.init_mean
        mean_num = prior_mean.repeat((bs, 1)) + torch.sum(X * Y.unsqueeze(2), dim=1)
        mean = torch.einsum('ijk,ik->ij', cov, mean_num)

        state = torch.cat([mean, cov.flatten(start_dim=1, end_dim=2)], dim=1)
        return state.repeat_interleave(self.repeat_suffstat, dim=1).to(device)

    # no mask as input, use mask later. 
    def eval_seq(self, Zkwargs, X, Y, true_p=None, N=None, return_preds=False, trainlen=None, exact=False):
        """Return the unreduced binary cross-entropy of one-step-ahead predictions.

        Column ``j`` is predicted from Z, ``X[:, j]`` and the state of columns
        ``0..j-1``.

        Args:
            Zkwargs: Input handed to the Z encoder (ignored by the no-Z encoder).
            X: 3-D tensor ``(batch, columns, X_dim)``.
            Y: 2-D target tensor ``(batch, columns)``.
            true_p: Unused; accepted for signature compatibility.
            N: Number of leading columns to score; defaults to, and is capped
                at, ``Y.shape[1]``.
            return_preds: When true, also return the predictions.
            trainlen: Unused; accepted for signature compatibility.
            exact: Unused. TODO: exact functionality is not implemented.

        Returns:
            The loss matrix ``(batch, N)``, or ``(loss_matrix, predictions)``
            when ``return_preds`` is true.
        """
        if N is None:
            N = Y.shape[1]
        N = min(N, Y.shape[1])

        # Truncate both inputs up front so every block below has N columns;
        # previously X was left untruncated and N < X.shape[1] broke the cat.
        X = X[:, :N]
        Y = Y[:, :N]

        encoded_Z = self.z_encoder(Zkwargs)
        # The no-Z encoder returns None; in that case there is no Z block.
        embedZ = None if encoded_Z is None else encoded_Z.unsqueeze(1).repeat((1, N, 1))

        # State before each column j, built from columns 0..j-1 only.
        all_states = [self.get_state(X[:, :j], Y[:, :j]) for j in range(N)]
        all_states_cat = torch.stack(all_states, dim=1)

        input_ = self._join_features(embedZ, X, all_states_cat)

        p_hat_pred = self.top_layer(input_).squeeze(2)
        loss_matrix = F.binary_cross_entropy(p_hat_pred, Y, reduction='none')
        if return_preds:
            return loss_matrix, p_hat_pred

        return loss_matrix

    def fill_table_naive(self, Z, hist_X, hist_Y, hist_mask, eval_X):
        """Sample outcomes for ``eval_X`` one column at a time, conditioning on the history.

        Each sampled outcome is appended to the history before the next column
        is sampled.

        Args:
            Z: 2-D tensor ``(batch, z_dim)`` concatenated as-is into the MLP
                input (it is not passed through ``z_encoder`` here), or None
                when the model has no Z features.
            hist_X: 3-D tensor ``(batch, steps, X_dim)`` of observed inputs.
            hist_Y: 2-D tensor ``(batch, steps)`` of observed outcomes.
            hist_mask: Not read by this method.
            eval_X: 3-D tensor ``(batch, m, X_dim)`` of inputs to sample for.

        Returns:
            Tensor ``(batch, m)`` of Bernoulli samples.
        """
        m = eval_X.shape[1]
        n_hist = hist_Y.shape[1]
        assert n_hist == hist_X.shape[1]  # number of timesteps observed

        # Work buffers follow the history's device instead of the default one.
        # Future slots start at zero, which get_state treats as "no data".
        Y = torch.zeros(hist_Y.shape[0], n_hist + m, device=hist_Y.device)
        X = torch.zeros(hist_X.shape[0], n_hist + m, hist_X.shape[2], device=hist_X.device)
        Y[:, :n_hist] = hist_Y
        X[:, :n_hist, :] = hist_X

        z_block = None if Z is None else Z.unsqueeze(1)

        for j in range(m):
            curr_state = self.get_state(X, Y)
            input_ = self._join_features(z_block, eval_X[:, [j], :], curr_state.unsqueeze(1))

            # Generate new Y's, for this j
            p_hat_pred = self.top_layer(input_).squeeze(2)
            new_cur_Ys = torch.bernoulli(p_hat_pred)

            # Update relevant parts of Y with the generated Y's
            Y[:, n_hist + j] = new_cur_Ys[:, 0]
            X[:, n_hist + j] = eval_X[:, j]

        return Y[:, n_hist:]

    def fill_table_naive_finite(self, Z, hist_X, hist_Y, hist_mask, eval_X):
        """Impute the unobserved entries of ``hist_Y`` column by column.

        Args:
            Z: 2-D tensor ``(batch, z_dim)`` concatenated as-is into the MLP
                input, or None when the model has no Z features.
            hist_X: Not read by this method.
            hist_Y: 2-D tensor ``(batch, T)``; entries where the mask is 0 are
                replaced by samples.
            hist_mask: 2-D numeric tensor ``(batch, T)``; 1 marks an observed
                entry, 0 a missing one.
            eval_X: 3-D tensor ``(batch, T, X_dim)`` of inputs for every column.

        Returns:
            A copy of ``hist_Y`` with the masked-out entries filled by
            Bernoulli samples; observed entries are unchanged.
        """
        T = hist_mask.shape[1]
        imputedY = torch.clone(hist_Y)
        curr_mask = torch.clone(hist_mask)
        z_block = None if Z is None else Z.unsqueeze(1)

        for t in range(T):
            # Zero out everything not yet observed or imputed so it drops out
            # of the state; the mask broadcasts over the feature axis.
            curr_state = self.get_state(eval_X * curr_mask.unsqueeze(2), imputedY * curr_mask)

            # Generate new Y's, for this column / t
            input_ = self._join_features(z_block, eval_X[:, [t], :], curr_state.unsqueeze(1))
            p_hat_pred = self.top_layer(input_).squeeze(2)
            gen_Ys = torch.bernoulli(p_hat_pred)

            # Keep observed values, take the sample where the entry was missing,
            # then mark the column as known for the following steps.
            imputedY[:, t] = imputedY[:, t] * curr_mask[:, t] + gen_Ys[:, 0] * (1 - curr_mask[:, t])
            curr_mask[:, t] = 1

        return imputedY
